{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is to prove the inconsistency of NID method.\n",
    "\n",
    "We fisrt run the NID code and then compare it with our method.\n",
    "\n",
    "The code will run for a few minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from neural_interaction_detection import get_interactions\n",
    "from multilayer_perceptron import MLP, train, get_weights\n",
    "from utils import preprocess_data, get_pairwise_auc, get_anyorder_R_precision, set_seed, print_rankings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_main_effect_nets = True # toggle this to use \"main effect\" nets\n",
    "num_samples = 30000\n",
    "num_features = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate synthetic data with ground truth interactions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def synth_func(X):\n",
    "    X1, X2, X3, X4, X5, X6, X7, X8 = X.transpose()\n",
    "    interaction1 = X1**2*X2*X3**2*X4                        \n",
    "    interaction2 = X5**2*X6*X7**2*X8\n",
    "\n",
    "    Y =              interaction1 + interaction2 \n",
    "    ground_truth = [     {1,2,3,4},        {5,6,7,8}     ]\n",
    "    \n",
    "    return Y, ground_truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(42)\n",
    "X = np.random.uniform(low=-1, high=1, size=(num_samples,num_features))\n",
    "Y, ground_truth = synth_func(X)\n",
    "data_loaders = preprocess_data(X, Y, valid_size=10000, test_size=10000, std_scale=True, get_torch_loaders=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a multilayer perceptron (MLP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")\n",
    "model = MLP(num_features, [140, 100, 60, 20], use_main_effect_nets=use_main_effect_nets).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "starting to train\n",
      "early stopping enabled\n",
      "[epoch 1, total 100] train loss: 0.8319, val loss: 0.8109\n",
      "[epoch 3, total 100] train loss: 0.5896, val loss: 0.6396\n",
      "[epoch 5, total 100] train loss: 0.2602, val loss: 0.2320\n",
      "[epoch 7, total 100] train loss: 0.1517, val loss: 0.2999\n",
      "[epoch 9, total 100] train loss: 0.1350, val loss: 0.1484\n",
      "[epoch 11, total 100] train loss: 0.0774, val loss: 0.1224\n",
      "[epoch 13, total 100] train loss: 0.0862, val loss: 0.1022\n",
      "[epoch 15, total 100] train loss: 0.0766, val loss: 0.0910\n",
      "[epoch 17, total 100] train loss: 0.0715, val loss: 0.0839\n",
      "[epoch 19, total 100] train loss: 0.0511, val loss: 0.0757\n",
      "[epoch 21, total 100] train loss: 0.0790, val loss: 0.1116\n",
      "[epoch 23, total 100] train loss: 0.0698, val loss: 0.1030\n",
      "[epoch 25, total 100] train loss: 0.0740, val loss: 0.0621\n",
      "[epoch 27, total 100] train loss: 0.0696, val loss: 0.0835\n",
      "[epoch 29, total 100] train loss: 0.0487, val loss: 0.0788\n",
      "[epoch 31, total 100] train loss: 0.0639, val loss: 0.0875\n",
      "[epoch 33, total 100] train loss: 0.0595, val loss: 0.0762\n",
      "[epoch 35, total 100] train loss: 0.0613, val loss: 0.0729\n",
      "early stopping!\n",
      "Finished Training. Test loss:  0.05183456093072891\n"
     ]
    }
   ],
   "source": [
    "model, mlp_loss = train(model, data_loaders, device=device, learning_rate=1e-2, l1_const = 5e-5, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the MLP's learned weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_weights = get_weights(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Detect interactions from the weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[((2, 4), 70.7305),\n",
       " ((6, 8), 61.814903),\n",
       " ((5, 6), 35.181442),\n",
       " ((6, 7), 33.913887),\n",
       " ((5, 7), 29.893131),\n",
       " ((5, 8), 22.185238),\n",
       " ((2, 8), 20.204613),\n",
       " ((2, 6), 19.646027),\n",
       " ((7, 8), 18.416555),\n",
       " ((4, 8), 18.112186),\n",
       " ((4, 6), 17.493275),\n",
       " ((1, 6), 15.431152),\n",
       " ((1, 3), 15.218493),\n",
       " ((2, 7), 14.409973),\n",
       " ((1, 5), 13.93491),\n",
       " ((3, 6), 13.413807),\n",
       " ((4, 5), 13.235764),\n",
       " ((2, 3), 13.13788),\n",
       " ((1, 8), 12.991602),\n",
       " ((1, 2), 12.955213),\n",
       " ((2, 5), 12.950325),\n",
       " ((1, 7), 12.683628),\n",
       " ((3, 8), 12.498389),\n",
       " ((4, 7), 12.068602),\n",
       " ((3, 7), 12.059618),\n",
       " ((3, 5), 11.738636),\n",
       " ((1, 4), 11.333738),\n",
       " ((3, 4), 10.039406)]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pairwise_interactions = get_interactions(model_weights, pairwise=True, one_indexed=True)\n",
    "pairwise_interactions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Detect interactions with our principled method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start dectecting\n",
      "Initialization done! initial try3times\n",
      "chosen arm: 0 strength: 0.9775244333012112 iteration: 177\n",
      "chosen arm: 7 strength: 0.6998232301504943 iteration: 208\n",
      "chosen arm: 2 strength: 0.8073215459666722 iteration: 275\n",
      "chosen arm: 13 strength: 0.6627016655325304 iteration: 338\n",
      "chosen arm: 27 strength: 0.8149593249545433 iteration: 359\n",
      "chosen arm: 25 strength: 0.40065102562099736 iteration: 385\n",
      "chosen arm: 24 strength: 5.1987242418581445 iteration: 423\n",
      "chosen arm: 1 strength: 0.09203294216400458 iteration: 477\n",
      "chosen arm: 22 strength: 0.40091696214383576 iteration: 487\n",
      "chosen arm: 26 strength: 0.16952629989439932 iteration: 493\n",
      "chosen arm: 23 strength: 0.09347923261619893 iteration: 524\n",
      "chosen arm: 8 strength: 0.02082446837529359 iteration: 620\n",
      "chosen arm: 16 strength: 0.010436514839125452 iteration: 681\n",
      "chosen arm: 5 strength: 0.00732620194638539 iteration: 773\n",
      "chosen arm: 9 strength: 0.004539545773195402 iteration: 858\n",
      "chosen arm: 11 strength: 0.009087912594623049 iteration: 867\n",
      "chosen arm: 4 strength: 0.003653592576301131 iteration: 1014\n",
      "chosen arm: 15 strength: 0.004199234304272701 iteration: 1101\n",
      "chosen arm: 3 strength: 0.00941262186027719 iteration: 1112\n",
      "chosen arm: 18 strength: 0.003248432181806985 iteration: 1156\n",
      "time: 9.656995296478271\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([array([0, 1]),\n",
       "  array([1, 2]),\n",
       "  array([0, 3]),\n",
       "  array([2, 3]),\n",
       "  array([6, 7]),\n",
       "  array([5, 6]),\n",
       "  array([4, 7]),\n",
       "  array([0, 2]),\n",
       "  array([4, 5]),\n",
       "  array([5, 7]),\n",
       "  array([4, 6]),\n",
       "  array([1, 3]),\n",
       "  array([2, 6]),\n",
       "  array([0, 6]),\n",
       "  array([1, 4]),\n",
       "  array([1, 6]),\n",
       "  array([0, 5]),\n",
       "  array([2, 5]),\n",
       "  array([0, 4]),\n",
       "  array([3, 4])],\n",
       " 9.656995296478271)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# we use the exact same model trained above\n",
    "from UCBtools import *\n",
    "X_torch = torch.Tensor(X)\n",
    "UCB_interactions=detect_Hessian_UCB(model,X_torch,20) \n",
    "(UCB_interactions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[((1, 2), 19),\n",
       " ((2, 3), 18),\n",
       " ((1, 4), 17),\n",
       " ((3, 4), 16),\n",
       " ((7, 8), 15),\n",
       " ((6, 7), 14),\n",
       " ((5, 8), 13),\n",
       " ((1, 3), 12),\n",
       " ((5, 6), 11),\n",
       " ((6, 8), 10),\n",
       " ((5, 7), 9),\n",
       " ((2, 4), 8),\n",
       " ((3, 7), 7),\n",
       " ((1, 7), 6),\n",
       " ((2, 5), 5),\n",
       " ((2, 7), 4),\n",
       " ((1, 6), 3),\n",
       " ((3, 6), 2),\n",
       " ((1, 5), 1),\n",
       " ((4, 5), 0)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# transfer pairwise interaction format only\n",
    "our_pairwise_interactions=[]\n",
    "c=range(len(UCB_interactions[0]))\n",
    "d=sorted(c, reverse=True)\n",
    "for i in c:\n",
    "    our_pairwise_interactions.append(((UCB_interactions[0][i][0]+1,UCB_interactions[0][i][1]+1),d[i]))\n",
    "(our_pairwise_interactions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the interactions (NID & Our method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NID Pairwise AUC 0.6979166666666666\n",
      "Our Pairwise AUC 1.0\n"
     ]
    }
   ],
   "source": [
    "NIDauc = get_pairwise_auc(pairwise_interactions, ground_truth)\n",
    "ourauc = get_pairwise_auc(our_pairwise_interactions, ground_truth)\n",
    "print(\"NID Pairwise AUC\", NIDauc)\n",
    "print(\"Our Pairwise AUC\", ourauc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml",
   "language": "python",
   "name": "ml"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
